﻿using System;
using RayDen.Library.Components.Surface;
using RayDen.Library.Core;
using RayDen.Library.Core.Primitives;
using RayDen.RayEngine.Core.Interface;
using RayDen.RayEngine.Core.Types;

namespace RayDen.RayEngine.Engines.BDPT
{
    /// <summary>
    /// Phase of a bidirectional path sample between ray-buffer passes; it decides which rays
    /// the sampler queues when the buffer is filled.
    /// </summary>
    public enum BDPTSamplerState
    {
        /// <summary>Initial state set when a path starts: the eye ray and the light ray are queued.</summary>
        InitPath = 0,

        /// <summary>Sub-paths are being extended: eye ray, light ray and any pending connection rays are queued.</summary>
        PropagatePath = 1,

        /// <summary>The eye path has ended with connections pending: only the connection rays are queued.</summary>
        ConnectOnly = 2,
    }

    /// <summary>
    /// Bidirectional path sampler. Each sample traces an eye sub-path from the camera and a light
    /// sub-path from a sampled light in lock-step through the shared ray buffer. Diffuse light-path
    /// vertices are stored; at diffuse eye vertices connection (shadow) rays are queued both to
    /// sampled lights and to the stored light vertices. The contributions of connection rays that
    /// pass <c>ShadowRayTest</c> are added to the inherited <c>Radiance</c> before the sample is splatted.
    /// </summary>
    public class BDPathSampler : PathSamplerBase
    {
        /// <summary>Maximum number of light-path bounces, and the capacity of the stored light-vertex list.</summary>
        private const int MaxLightDepth = 4;

        /// <summary>Value of <c>RayHit.Index</c> marking a ray that hit nothing.</summary>
        private const uint NoHit = 0xffffffffu;

        protected BDPTSamplerState PathState;
        protected internal int LightRayIndex;
        protected internal RayData LightRay;
        protected RayEngineScene scene;
        protected SurfaceIntersectionData hitInfo = null;

        /// <summary>Not used within this class.</summary>
        public int MaxRaysPerPath
        {
            get;
            set;
        }

        /// <summary>Throughput carried along the light sub-path and the eye sub-path respectively.</summary>
        public RgbSpectrum LightThroughput, EyeThroughput;

        // Diffuse vertices recorded on the light sub-path; the first lightVertices entries are valid.
        private VertexInfo[] lightPath;

        // Connection rays queued at an eye vertex; the first tracedConnectRayCount entries are valid.
        private ConnectRayInfo[] connectRays;

        private int eyeDepth, lightDepth, tracedConnectRayCount, lightVertices;
        private bool eyeStop, lightStop;
        private float eyePdf, lightPdf;

        /// <summary>
        /// Starts a new sample: resets radiance and per-path bookkeeping, generates the camera ray
        /// for the eye sub-path, and samples a light to emit the first ray of the light sub-path
        /// (the light-selection probability is taken as 1 / Lights.Length).
        /// </summary>
        /// <param name="buffer">Path processor forwarded to the base implementation.</param>
        public override void InitPath(IPathProcessor buffer)
        {
            base.InitPath(buffer);
            this.scene = pathIntegrator.Scene;

            this.Radiance = new RgbSpectrum(0f);
            this.PathState = BDPTSamplerState.InitPath;

            // A single eye vertex can queue ShadowRayCount direct-light rays plus one ray per stored
            // light vertex (at most MaxLightDepth) before they are traced, so the buffer must hold at
            // least that many; MaxPathDepth^2 is kept as a lower bound. Regrow if a previous scene left
            // a smaller buffer behind.
            var requiredConnectRays = Math.Max(scene.MaxPathDepth * scene.MaxPathDepth,
                                               scene.ShadowRayCount + MaxLightDepth);
            if (this.connectRays == null || this.connectRays.Length < requiredConnectRays)
                this.connectRays = new ConnectRayInfo[requiredConnectRays];

            // Slots below lightVertices are always fully rewritten before being read,
            // so the light-vertex array can be reused rather than reallocated per sample.
            if (this.lightPath == null)
                this.lightPath = new VertexInfo[MaxLightDepth];

            this.Sample = pathIntegrator.Sampler.GetSample(null);
            this.lightVertices = 0;

            // Eye sub-path: primary camera ray.
            this.EyeThroughput = new RgbSpectrum(1f);
            pathIntegrator.Scene.Camera.GetRay(Sample.imageX, Sample.imageY, out this.PathRay);
            this.RayIndex = -1;
            this.LightRayIndex = -1;
            this.eyePdf = 1f;

            // Light sub-path: pick a light and emit a photon ray from it.
            LightSample ls;
            var light = scene.Lights[scene.SampleLights(Sample.GetLazyValue())];
            light.EvaluatePhoton(scene, Sample.GetLazyValue(), Sample.GetLazyValue(), Sample.GetLazyValue(), Sample.GetLazyValue(), Sample.GetLazyValue(), out ls);
            LightRay = ls.LightRay;
            lightPdf = ls.Pdf * (1f / scene.Lights.Length);
            LightThroughput = (RgbSpectrum)(ls.Spectrum);

            this.tracedConnectRayCount = 0;
            this.eyeDepth = 0;
            this.lightDepth = 0;
            this.eyeStop = false;
            this.lightStop = false;
        }

        /// <summary>
        /// Queues this path's rays for tracing. The eye and light rays are queued in every state except
        /// ConnectOnly. Every pending connection ray is queued in PropagatePath and ConnectOnly.
        /// </summary>
        /// <param name="rayBuffer">Buffer receiving the rays; their indices are remembered for <c>Advance</c>.</param>
        /// <returns><c>false</c>, having queued nothing, when the buffer lacks room for all of them.</returns>
        public override bool FillRayBuffer(RayBuffer rayBuffer)
        {
            var leftSpace = rayBuffer.LeftSpace();
            if (((PathState == BDPTSamplerState.InitPath) && (2 > leftSpace)) ||
                ((PathState == BDPTSamplerState.ConnectOnly) && (tracedConnectRayCount > leftSpace)) ||
                ((PathState == BDPTSamplerState.PropagatePath) && (tracedConnectRayCount + 2 > leftSpace)))
                return false;

            if (PathState != BDPTSamplerState.ConnectOnly)
            {
                RayIndex = rayBuffer.AddRay(ref PathRay);
                LightRayIndex = rayBuffer.AddRay(ref LightRay);
            }

            if (PathState == BDPTSamplerState.PropagatePath || PathState == BDPTSamplerState.ConnectOnly)
            {
                for (int i = 0; i < tracedConnectRayCount; ++i)
                    connectRays[i].ConnectRayIndex = rayBuffer.AddRay(ref connectRays[i].ConnectRay);
            }
            return true;
        }

        /// <summary>
        /// Consumes the results of the last trace.
        /// <list type="bullet">
        /// <item>If connection rays were queued, their contributions are accumulated and the sample is splatted.</item>
        /// <item>Otherwise the light sub-path and then the eye sub-path are each extended by one bounce.</item>
        /// <item>The sample is splatted once both sub-paths have stopped and no connections are pending.</item>
        /// </list>
        /// </summary>
        public override void Advance(RayBuffer rayBuffer, SampleBuffer consumer)
        {
            base.Advance(rayBuffer, consumer);

            if ((PathState == BDPTSamplerState.ConnectOnly || PathState == BDPTSamplerState.PropagatePath) &&
                tracedConnectRayCount > 0)
            {
                AccumulateConnections(rayBuffer);
                Splat(consumer);
                return;
            }

            AdvanceLightPath(rayBuffer);

            if (AdvanceEyePath(rayBuffer, consumer))
                return;

            // When connection rays are still queued, PathState is ConnectOnly and they are traced and
            // splatted on the next pass; splatting here would discard their contribution.
            if (eyeStop && lightStop && tracedConnectRayCount == 0)
                Splat(consumer);
        }

        /// <summary>Adds the contribution of every queued connection ray whose trace passes <c>ShadowRayTest</c>.</summary>
        private void AccumulateConnections(RayBuffer rayBuffer)
        {
            for (int i = 0; i < tracedConnectRayCount; ++i)
            {
                RayHit shadowRayHit = rayBuffer.rayHits[connectRays[i].ConnectRayIndex];
                RgbSpectrum attenuation;
                if (ShadowRayTest(ref shadowRayHit, out attenuation))
                    Radiance += attenuation * (connectRays[i].Radiance / connectRays[i].Pdf);
            }
        }

        /// <summary>
        /// Extends the light sub-path by one bounce from the traced light ray, storing the hit as a
        /// connectable vertex when its material is diffuse. Sets <c>lightStop</c> when the path ends.
        /// </summary>
        private void AdvanceLightPath(RayBuffer rayBuffer)
        {
            if (lightStop)
                return;

            RayHit rayHit = rayBuffer.rayHits[LightRayIndex];
            lightDepth++;
            if (rayHit.Index == NoHit || lightDepth > MaxLightDepth)
            {
                lightStop = true;
                return;
            }

            if (hitInfo == null)
                hitInfo = SurfaceSampler.GetIntersection(ref LightRay, ref rayHit);
            else
                SurfaceSampler.GetIntersection(ref LightRay, ref rayHit, ref hitInfo);
            var hitPoint = LightRay.Point(rayHit.Distance);

            // The light sub-path ends when it reaches an emitter.
            if (scene.IsLight((int)rayHit.Index))
            {
                lightStop = true;
                return;
            }

            Vector wo = -LightRay.Dir;
            var bsdf = hitInfo.MMaterial;
            float fPdf;
            Vector wi;
            bool specularBounce;
            hitInfo.Color = RgbSpectrum.UnitSpectrum();

            RgbSpectrum f = bsdf.Sample_f(ref wo,
                                          out wi,
                                          ref hitInfo.Normal,
                                          ref hitInfo.ShadingNormal, ref LightThroughput,
                                          Sample.GetLazyValue(),
                                          Sample.GetLazyValue(),
                                          Sample.GetLazyValue(),
                                          ref hitInfo.TextureData,
                                          out fPdf,
                                          out specularBounce);

            // Either condition ends the path. The throughput update below divides by fPdf,
            // so continuing with fPdf <= 0 would inject infinities/NaNs into the light path.
            if (f.IsBlack() || fPdf <= 0f)
            {
                lightStop = true;
                return;
            }

            // Only diffuse vertices are stored. The throughput recorded is the value before this
            // bounce's f / pdf factor is applied.
            if (bsdf.IsDiffuse() && lightVertices < MaxLightDepth)
            {
                lightPath[lightVertices].HitPoint = hitPoint;
                lightPath[lightVertices].Wi = wo;
                lightPath[lightVertices].GeoNormal = hitInfo.Normal;
                lightPath[lightVertices].Throughput = LightThroughput;
                lightPath[lightVertices].Pdf = lightPdf;
                lightVertices++;
            }

            LightThroughput *= f / fPdf;

            LightRay.Org = hitPoint;
            LightRay.Dir = wi.Normalize();
            PathState = BDPTSamplerState.PropagatePath;
        }

        /// <summary>
        /// Extends the eye sub-path by one bounce from the traced eye ray.
        /// <list type="bullet">
        /// <item>A miss, or exceeding the maximum depth, adds the environment term and splats.</item>
        /// <item>An emitter hit adds its emission and splats.</item>
        /// <item>A diffuse hit queues connection rays before the BSDF is sampled for the next bounce.</item>
        /// </list>
        /// </summary>
        /// <returns><c>true</c> when the sample has been splatted and <c>Advance</c> must return immediately.</returns>
        private bool AdvanceEyePath(RayBuffer rayBuffer, SampleBuffer consumer)
        {
            if (eyeStop)
                return false;

            RayHit rayHit = rayBuffer.rayHits[RayIndex];
            eyeDepth++;
            // NOTE(review): the environment term is also added when the depth limit is exceeded on a
            // ray that did hit geometry — confirm that is intended.
            if (rayHit.Index == NoHit || eyeDepth > scene.MaxPathDepth)
            {
                Radiance += this.SampleEnvironment(PathRay.Dir) * EyeThroughput;
                Splat(consumer);
                return true;
            }

            if (hitInfo == null)
                hitInfo = SurfaceSampler.GetIntersection(ref PathRay, ref rayHit);
            else
                SurfaceSampler.GetIntersection(ref PathRay, ref rayHit, ref hitInfo);

            var currentTriangleIndex = (int)rayHit.Index;
            if (hitInfo.IsLight)
            {
                var lt = scene.GetLightByIndex(currentTriangleIndex);
                if (lt != null)
                {
                    var le = (RgbSpectrum)(lt.Le(ref PathRay.Dir));
                    Radiance += EyeThroughput * le;
                }
                Splat(consumer);
                return true;
            }

            var hitPoint = PathRay.Point(rayHit.Distance);
            Vector wo = -PathRay.Dir;
            var bsdf = hitInfo.MMaterial;

            if (bsdf.IsDiffuse())
            {
                // Direct lighting: ShadowRayCount light samples, each weighted by the light-selection strategy pdf.
                float lightStrategyPdf = scene.ShadowRayCount / (float)scene.Lights.Length;
                RgbSpectrum surfaceThroughput = EyeThroughput * hitInfo.Color;
                for (int i = 0; i < scene.ShadowRayCount; ++i)
                {
                    int currentLightIndex = scene.SampleLights(Sample.GetLazyValue());
                    var light = scene.Lights[currentLightIndex];

                    var ls = new LightSample();
                    light.EvaluateShadow(ref hitPoint,
                                         ref hitInfo.ShadingNormal,
                                         Sample.GetLazyValue(),
                                         Sample.GetLazyValue(),
                                         Sample.GetLazyValue(),
                                         ref ls);
                    if (ls.Pdf <= 0f)
                        continue;

                    connectRays[tracedConnectRayCount].Radiance = (RgbSpectrum)(ls.Spectrum);
                    connectRays[tracedConnectRayCount].Pdf = ls.Pdf;
                    connectRays[tracedConnectRayCount].ConnectRay = ls.LightRay;
                    connectRays[tracedConnectRayCount].Direct = true;

                    Vector mwi = connectRays[tracedConnectRayCount].ConnectRay.Dir;
                    Vector lwi = connectRays[tracedConnectRayCount].ConnectRay.Dir;
                    RgbSpectrum fs;
                    hitInfo.MMaterial.f(
                        ref mwi,
                        ref wo,
                        ref hitInfo.ShadingNormal, ref surfaceThroughput,
                        out fs,
                        types: BrdfType.Diffuse);
                    connectRays[tracedConnectRayCount].Radiance *= surfaceThroughput *
                                                                   Vector.AbsDot(ref hitInfo.ShadingNormal, ref lwi) *
                                                                   fs;

                    // Only rays that can contribute are kept.
                    if (!connectRays[tracedConnectRayCount].Radiance.IsBlack())
                    {
                        connectRays[tracedConnectRayCount].Pdf *= lightStrategyPdf;
                        tracedConnectRayCount++;
                    }
                }

                // Connections from this eye vertex to every stored light-path vertex.
                for (int i = 0; i < lightVertices; i++)
                {
                    var vertexThroughput = lightPath[i].Throughput;

                    connectRays[tracedConnectRayCount].Radiance = EyeThroughput * hitInfo.Color;
                    connectRays[tracedConnectRayCount].Pdf = 1f;
                    // NOTE(review): this direction points from the light vertex towards hitPoint, i.e. away
                    // from the light vertex, and the ray is unbounded. ShadowRayTest's "hit nothing" check
                    // therefore does not test visibility between the two vertices. Presumably it should be
                    // (lightPath[i].HitPoint - hitPoint) with the ray clipped at the vertex distance —
                    // confirm against RayData's API before changing.
                    connectRays[tracedConnectRayCount].ConnectRay = new RayData(hitPoint, (hitPoint - lightPath[i].HitPoint).Normalize());
                    connectRays[tracedConnectRayCount].Direct = false;

                    // NOTE(review): no BSDF is evaluated at either end of this connection; only the
                    // geometry term is applied.
                    connectRays[tracedConnectRayCount].Radiance *= vertexThroughput *
                                                                   Geometry.G(ref hitPoint, ref lightPath[i].HitPoint, ref hitInfo.Normal, ref lightPath[i].GeoNormal);
                    if (!connectRays[tracedConnectRayCount].Radiance.IsBlack())
                    {
                        connectRays[tracedConnectRayCount].Pdf *= lightPath[i].Pdf * eyePdf;
                        tracedConnectRayCount++;
                    }
                }
            }

            float fPdf;
            Vector wi;
            bool specularBounce;
            RgbSpectrum f = bsdf.Sample_f(ref wo,
                                          out wi,
                                          ref hitInfo.Normal,
                                          ref hitInfo.ShadingNormal, ref EyeThroughput,
                                          Sample.GetLazyValue(),
                                          Sample.GetLazyValue(),
                                          Sample.GetLazyValue(),
                                          ref hitInfo.TextureData,
                                          out fPdf,
                                          out specularBounce);
            if (fPdf <= 0.0f || f.IsBlack())
            {
                // The eye path ends here; connections queued above are traced on the next pass.
                if (tracedConnectRayCount > 0)
                    PathState = BDPTSamplerState.ConnectOnly;
                eyeStop = true;
                return false;
            }

            eyePdf *= fPdf;
            EyeThroughput *= f / fPdf;

            // NOTE(review): unreachable as written. eyeDepth > scene.MaxPathDepth already returned
            // above, so Russian roulette never runs. A separate RR start depth was presumably
            // intended — confirm.
            if (eyeDepth > scene.MaxPathDepth)
            {
                float prob = Math.Max(EyeThroughput.Filter(), scene.RussianRuletteImportanceCap);
                if (prob >= Sample.GetLazyValue())
                {
                    EyeThroughput /= prob;
                    eyePdf *= prob;
                }
                else
                {
                    if (tracedConnectRayCount > 0)
                        PathState = BDPTSamplerState.ConnectOnly;
                    eyeStop = true;
                    return false;
                }
            }

            PathRay.Org = hitPoint;
            PathRay.Dir = wi.Normalize();
            PathState = BDPTSamplerState.PropagatePath;
            return false;
        }

        /// <summary>
        /// Visibility test for a traced connection ray: it passes only when the ray hit nothing.
        /// </summary>
        /// <param name="shadowRayHit">Intersection result of the connection ray.</param>
        /// <param name="attenuation">Always white (1); no partial transmission is modelled.</param>
        /// <returns><c>true</c> when the ray missed all geometry.</returns>
        private bool ShadowRayTest(ref RayHit shadowRayHit, out RgbSpectrum attenuation)
        {
            attenuation = new RgbSpectrum(1f);
            return shadowRayHit.Index == NoHit;
        }

        /// <summary>Environment radiance for a ray travelling along <paramref name="vector"/> (the scene is queried with the negated direction).</summary>
        private RgbSpectrum SampleEnvironment(Vector vector)
        {
            return this.scene.SampleEnvironment(-vector);
        }

        /// <summary>A diffuse vertex recorded on the light sub-path for later connection to eye vertices.</summary>
        protected struct VertexInfo
        {
            public Point HitPoint;
            public Normal GeoNormal;
            /// <summary>Light-path throughput recorded at this vertex.</summary>
            public RgbSpectrum Throughput;
            /// <summary>Direction back along the incoming light ray (stored but not read in this class).</summary>
            public Vector Wi;
            /// <summary>Light-path pdf recorded at this vertex.</summary>
            public float Pdf;
        }

        /// <summary>A connection (shadow) ray queued at an eye vertex together with its pending contribution.</summary>
        private struct ConnectRayInfo
        {
            /// <summary>Contribution added (divided by <see cref="Pdf"/>) when the ray passes the shadow test.</summary>
            public RgbSpectrum Radiance;
            public float Pdf;
            public RayData ConnectRay;
            /// <summary>Index of the ray in the ray buffer, assigned when it is queued.</summary>
            public int ConnectRayIndex;
            /// <summary>True for direct-light shadow rays, false for light-vertex connections (not read in this class).</summary>
            public bool Direct;
        }
    }
}
